from .base_reasoner import BaseReasoner, ReasoningNode
import asyncio
import argparse
import json
import os
import re
import time
import traceback
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any, Union
import random
import openai
from collections import defaultdict
from datetime import datetime

class Gsm8kReasoner(BaseReasoner):
    """Prompt-engineering reasoner for the GSM8K math word-problem dataset.

    ``execute_workflow`` drives one question through: constraint extraction
    -> root node -> optional case classification -> per-node LLM solving ->
    answer aggregation. ``save_results`` / ``verify_answer`` compare the
    produced answer with the reference answer marked ``#### <answer>``.

    State used here but defined by ``BaseReasoner`` (not visible in this
    file): ``self.config``, ``self.llm``, ``self.nodes``, ``self.logs``,
    ``self.temp_list``, ``_create_node`` and ``_log_step``.
    NOTE(review): ``execute_workflow`` only appends to ``self.temp_list``;
    confirm the base class resets it between problems.
    """

    # Literal opening of a LaTeX box in model output.
    _BOXED_MARKER = "\\boxed{"
    # "Final Answer: ..." (solver prompt) or "Aggregated Answer: ..."
    # (aggregation prompt); both prompts request such a standalone line.
    _LABELLED_ANSWER_RE = re.compile(
        r"(?:Final|Aggregated)\s+Answer\s*:\s*([^\n]+)", re.IGNORECASE
    )
    # A signed integer or decimal literal (used after commas are stripped).
    _NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")

    def __init__(self):
        super().__init__("gsm8k")
        self.config.dataset_path = "datasets/gsm8k.json"

    async def load_problems(self, start_idx: int, end_idx: int) -> List[Dict]:
        """Load problems ``data[start_idx:end_idx]`` from the JSON dataset.

        Best-effort: any error while reading, parsing or slicing the dataset
        is printed and an empty list is returned so a batch driver can go on.
        """
        try:
            with open(self.config.dataset_path, "r", encoding="utf-8") as f:
                data = json.load(f)
                return data[start_idx:end_idx]
        except Exception as e:
            print(f"Error loading dataset: {str(e)}")
            return []

    async def execute_workflow(self, question: str) -> Dict[str, Any]:
        """Run the full prompt-engineering workflow for one question.

        Steps (tags match the ``_log_step`` calls):
          step0  extract explicit/implicit constraints with the LLM
          step1  create the root node
          step4  ask the LLM whether case classification is needed
          step5  create one child node per classification case
          step6  log the list of node ids to solve (``self.temp_list``)
          step7  solve each node
          step8  aggregate the per-node answers

        Returns:
            On success ``{"status": "success", "final_answer", "nodes",
            "logs", "token_usage"}``; on any exception the traceback is
            printed and ``{"status": "error", "message", "logs"}`` returned.
        """
        try:
            # Step 0: Extract problem constraints
            constraints = await self._extract_constraints(question)
            if not isinstance(constraints, dict):
                constraints = {
                    "explicit": [],
                    "implicit": [],
                    "notes": "Invalid constraints format"
                }
            self._log_step("step0", "system", {"constraints": constraints})

            # Step1: Create root node with initial constraints
            root = self._create_node(
                path=[],
                method={"description": "Original problem"},
                steps=[],
                constraints={
                    "explicit": self._as_list(constraints.get("explicit")),
                    "implicit": self._as_list(constraints.get("implicit"))
                },
                question=question
            )
            self._log_step("step1", root.node_id, {"question": question})

            # Step4: Check if classification needed. The question is passed
            # explicitly: the root's method/steps alone ("Original problem",
            # []) carry no information about the problem.
            classification_result = await self._check_classification(
                root.method["description"],
                root.steps,
                question=question
            )
            self._log_step("step4", root.node_id, classification_result)

            # Step5: Create classification nodes with combined constraints
            case_nodes_created = 0
            if classification_result["need_classify"]:
                for case in classification_result["cases"]:
                    combined_constraints = self._merge_case_constraints(
                        root.constraints, case
                    )
                    node = self._create_node(
                        path=[root.node_id],
                        method=root.method,
                        steps=root.steps,
                        score=0,
                        constraints=combined_constraints,
                        parent_id=root.node_id,
                        question=question
                    )
                    root.children.append(node.node_id)
                    self.temp_list.append(node.node_id)
                    case_nodes_created += 1
                    self._log_step("step5", node.node_id, {
                        "case": case,
                        "combined_constraints": combined_constraints
                    })

            if case_nodes_created == 0:
                # Either no classification was needed, or the LLM asked for
                # one but supplied no usable cases: solve the root directly so
                # the problem is still attempted.
                self.temp_list.append(root.node_id)

            # Step6: Build temporary list (handled in step5)
            self._log_step("step6", "system", {"temp_list": self.temp_list})

            # Step7: Solve nodes iteratively
            solutions = []
            for node_id in self.temp_list:
                solution = await self._solve_node(node_id)
                if solution:
                    solutions.append(solution)
                    self._log_step("step7", node_id, {"solution": solution})

            # Step8: Aggregate answers
            final_answer = await self._aggregate_answers(solutions)
            self._log_step("step8", "system", {"final_answer": final_answer})

            return {
                "status": "success",
                "final_answer": final_answer,
                "nodes": self.nodes,
                "logs": self.logs,
                "token_usage": self.llm.token_counts
            }

        except Exception as e:
            traceback.print_exc()
            return {
                "status": "error",
                "message": str(e),
                "logs": self.logs
            }

    @staticmethod
    def _as_list(value: Any) -> List[Any]:
        """Coerce an LLM-provided constraint field to a list.

        Lists pass through unchanged, ``None``/``""`` become ``[]``, and any
        other single value is wrapped as ``[value]``.
        """
        if isinstance(value, list):
            return value
        if value is None or value == "":
            return []
        return [value]

    @staticmethod
    def _as_bool(value: Any) -> bool:
        """Interpret a JSON flag that the LLM may have emitted as a string."""
        if isinstance(value, str):
            # bool("false") would be True; parse the text instead.
            return value.strip().lower() in ("true", "yes", "1")
        return bool(value)

    @classmethod
    def _merge_case_constraints(cls, parent: Dict[str, Any], case: Any) -> Dict[str, List[Any]]:
        """Copy the parent's constraints and add one case's constraints.

        Case entries keyed ``"explicit"``/``"implicit"`` go into that list
        (list values are spread in). Any other ``key: value`` becomes an
        implicit constraint string ``"key: value"``. A missing or non-dict
        ``case["constraints"]`` contributes nothing, instead of aborting the
        workflow.
        """
        combined = {
            "explicit": list(cls._as_list(parent.get("explicit"))),
            "implicit": list(cls._as_list(parent.get("implicit")))
        }
        case_constraints = case.get("constraints") if isinstance(case, dict) else None
        if not isinstance(case_constraints, dict):
            case_constraints = {}
        for key, value in case_constraints.items():
            if key in combined:
                if isinstance(value, list):
                    combined[key].extend(value)
                else:
                    combined[key].append(value)
            else:
                combined["implicit"].append(f"{key}: {value}")
        return combined

    async def _extract_constraints(self, question: str) -> Dict[str, Any]:
        """Ask the LLM for the problem's explicit and implicit constraints.

        Retries up to ``self.config.max_retries`` times when the response is
        not valid JSON, not an object, or has no constraints at all.

        Returns:
            ``{"explicit": list, "implicit": list, "notes": ...}``. After all
            retries fail, placeholder constraints are returned.
        """
        prompt = f"""You are a world-class mathematician and mathematical logician.  
    You are intelligent, rigorous, and cautious.  
    You always reason step by step, consider all relevant constraints.  
    You think in terms of structure, symmetry, and mathematical principles, and never skip important logical steps.  
    You aim to find a complete and correct solution, not just an answer.  
    You THINK CLEARLY, STRUCTURALLY, AND DEEPLY. 
    Analyze this math problem and extract ALL constraints:
    
    problem:{question}
    
    Notice:
    1. Identify explicit constraints (directly stated in the problem)
    2. Derive implicit constraints (e.g., denominators ≠ 0, square roots ≥ 0, log arguments > 0)
    3. Determine domain restrictions based on mathematical principles
    4. Identify range limitations from problem context
    5. Extract physical meaning constraints (e.g., length > 0, probability ∈ [0,1])
    
    Output JSON format:
    {{
        "explicit": ["constraint1", "constraint2"],
        "implicit": ["constraint1", "constraint2"],
        "notes": "Additional analysis notes"
    }}"""

        for attempt in range(self.config.max_retries):
            try:
                response = await self.llm.generate(prompt, response_format="json_object")
                data = json.loads(response)

                if not isinstance(data, dict):
                    print(f"Invalid response type (attempt {attempt+1}): {type(data)}")
                    continue

                # Coerce to lists so downstream .copy()/.append() always work.
                constraints = {
                    "explicit": self._as_list(data.get("explicit")),
                    "implicit": self._as_list(data.get("implicit")),
                    "notes": data.get("notes", "")
                }

                if not (constraints["explicit"] or constraints["implicit"]):
                    print(f"Empty constraints (attempt {attempt+1})")
                    continue

                return constraints

            except (json.JSONDecodeError, AttributeError, TypeError) as e:
                # TypeError covers a None/non-string LLM response.
                print(f"Parse error (attempt {attempt+1}): {str(e)}")
                continue

        print("All retries failed, returning default constraints")
        return {
            "explicit": ["Default explicit constraint"],
            "implicit": ["Default implicit constraint"],
            "notes": "Fallback constraints"
        }

    async def _check_classification(self, method: str, steps: List[str],
                                    question: Optional[str] = None) -> Dict[str, Any]:
        """Ask the LLM whether the problem needs a case-by-case treatment.

        Args:
            method: Description of the solution approach.
            steps: Solution steps so far.
            question: The problem text. When given, it is included in the
                prompt; without it the LLM has no view of the problem.

        Returns:
            ``{"need_classify": bool, "reason": ..., "cases": list[dict]}``.
            Unparseable or non-object responses yield
            ``{"need_classify": False, "reason": "Parse failed", "cases": []}``.
        """
        problem_section = f"Problem: {question}\n" if question else ""
        prompt = f"""You are a world-class mathematician and mathematical logician.  
    You are intelligent, rigorous, and cautious.  
    You always reason step by step, consider all relevant constraints.  
    You think in terms of structure, symmetry, and mathematical principles, and never skip important logical steps.  
    You aim to find a complete and correct solution, not just an answer.  
    You THINK CLEARLY, STRUCTURALLY, AND DEEPLY. 
    Determine if this solution requires classification:

{problem_section}Method: {method}
Steps: {steps}

Notice:
1. Identify parameter dependencies requiring discussion
2. Detect interval-specific elements (absolute values, piecewise functions)
3. Recognize domain-specific computation requirements
4. Flag multiple solution sets needing verification
5. Pay attention to the mathematical expressions in the questions and understand them correctly
6. examine carefully the subject matter

If classification needed, provide:
- Comprehensive case descriptions
- Precise mathematical constraints for each case
- Clear boundary conditions

Output JSON format:
{{
    "need_classify": true/false,
    "reason": "Classification rationale",
    "cases": [
        {{
            "description": "Case description",
            "constraints": {{"parameter": "value_range"}}
        }}
    ]
}}"""

        response = await self.llm.generate(prompt, response_format="json_object")
        try:
            data = json.loads(response)
        except (json.JSONDecodeError, TypeError):
            data = None
        if not isinstance(data, dict):
            print(f"Failed to parse classification response: {response}")
            return {"need_classify": False, "reason": "Parse failed", "cases": []}

        cases = data.get("cases", [])
        if not isinstance(cases, list):
            cases = []
        return {
            "need_classify": self._as_bool(data.get("need_classify", False)),
            "reason": data.get("reason", ""),
            # Only dict cases can be merged into node constraints.
            "cases": [case for case in cases if isinstance(case, dict)]
        }

    def _original_question(self, node_id: str) -> Optional[str]:
        """Return the question text stored on a node, or on its root node.

        Nodes are created with ``question=...``. The root is ``path[0]`` when
        the node has a non-empty path. Returns None when neither has one.
        """
        node = self.nodes[node_id]
        question = getattr(node, "question", None)
        if question:
            return question
        if node.path:
            root_node = self.nodes.get(node.path[0])
            if root_node is not None:
                return getattr(root_node, "question", None) or None
        return None

    async def _solve_node(self, node_id: str) -> Optional[Dict[str, Any]]:
        """Ask the LLM to solve one node under that node's constraints.

        On success, stores the answer on the node (``node.answer``,
        ``node.state = "solved"``) and returns
        ``{"node_id", "response", "answer"}``. Returns None when no answer
        can be extracted from the response.
        """
        node = self.nodes[node_id]
        original_question = self._original_question(node_id) or "Original problem"

        prompt = f"""You are a world-class mathematician and mathematical logician.  
    You are intelligent, rigorous, and cautious.  
    You always reason step by step, consider all relevant constraints.  
    You think in terms of structure, symmetry, and mathematical principles, and never skip important logical steps.  
    You aim to find a complete and correct solution, not just an answer.  
    You THINK CLEARLY, STRUCTURALLY, AND DEEPLY. 
    You are a meticulous mathematical problem solver executing this solution:
    
    Original Problem: {original_question}
    Steps: {node.steps}
    Constraints: {node.constraints}
    
    As an executor, you must:
    1. Follow the provided steps precisely
    2. Explicitly verify all constraints at each step
    3. Show complete mathematical justification
    4. Use proper mathematical notation
    5. Clearly mark the final answer with \\boxed{{}}
    6. Include standalone line: "Final Answer: answer"
    7. Ensure your answer directly responds to the question asked
    8. The final answer should be one exact number
    9. Pay attention to the mathematical expressions in the questions and understand them correctly
    10. examine carefully the subject matter
    
    Additional requirements:
    - Show all intermediate calculations
    - State any assumptions made
    - Verify solution satisfies all constraints
    - Cross-validate critical steps
    - If the question asks for GCD, provide only the GCD as final answer
    - If you calculate intermediate values (like A and B), clearly distinguish them from the final answer"""

        response = await self.llm.generate(prompt)
        answer = self._extract_answer(response)

        if not answer:
            return None
        node.answer = answer
        node.state = "solved"
        return {
            "node_id": node_id,
            "response": response,
            "answer": answer
        }

    def _aggregation_question(self, solutions: List[Dict[str, Any]]) -> str:
        """Best-effort recovery of the original question for the aggregation prompt.

        Prefers the ``question`` stored on a solved node or its root, which is
        what ``_solve_node`` uses. It then tries the older fallbacks: an
        ``original_question`` attribute, the root's method description, and
        an "Original Problem: ..." line in the first response.
        """
        for sol in solutions:
            node = self.nodes[sol["node_id"]]
            question = (self._original_question(sol["node_id"])
                        or getattr(node, "original_question", None))
            if question:
                return question

        first_node = self.nodes[solutions[0]["node_id"]]
        if first_node.path:
            root_node = self.nodes.get(first_node.path[0])
            if root_node is not None:
                return root_node.method.get("description", "Original problem")

        match = re.search(r'Original Problem[:\s]*(.+?)\nSteps:', solutions[0]["response"] or "")
        if match:
            return match.group(1).strip()
        return "Original problem (reconstructed from context)"

    async def _aggregate_answers(self, solutions: List[Dict[str, Any]]) -> str:
        """Combine per-node solutions into one final answer.

        Returns without an LLM call when there are no solutions, or when
        every solution gives the same answer string (including the
        single-solution case). Otherwise the LLM adjudicates. It receives the
        original question, plus each solution's answer, approach and
        constraints, and the first 300 characters of its reasoning.

        Returns:
            The chosen answer. Returns "No valid solutions found" for empty
            input, or "Aggregation failed" when the adjudication response
            contains no extractable answer.
        """
        if not solutions:
            return "No valid solutions found"

        # Unanimous (or single) answer: no need to resolve the question or
        # spend an LLM call.
        if len({sol["answer"] for sol in solutions}) == 1:
            return solutions[0]["answer"]

        original_question = self._aggregation_question(solutions)

        solutions_text = "\n\n".join(
            f"Solution {i+1} (Node: {sol['node_id']}):\n"
            f"Answer: {sol['answer']}\n"
            f"Approach: {self.nodes[sol['node_id']].method['description']}\n"
            f"Constraints: {self.nodes[sol['node_id']].constraints}\n"
            f"Reasoning Excerpt:\n{sol['response'][:300]}...\n"
            for i, sol in enumerate(solutions)
        )

        prompt = f"""You are a world-class mathematician and mathematical logician.  
    You are intelligent, rigorous, and cautious.  
    You always reason step by step, consider all relevant constraints.  
    You think in terms of structure, symmetry, and mathematical principles, and never skip important logical steps.  
    You aim to find a complete and correct solution, not just an answer.  
    You THINK CLEARLY, STRUCTURALLY, AND DEEPLY. 
    Synthesize these solutions for the original problem:
    
    Original Problem: {original_question}
    
    Proposed Solutions:
    {solutions_text}
    
    As an analyst, you must:
    1. FIRST verify which solution(s) correctly answer the original question
    2. Compare mathematical consistency with the original problem statement
    3. Evaluate which approach best satisfies all constraints
    4. Combine elements from multiple solutions ONLY if mathematically valid
    5. Provide clear justification for your selection
    6. Mark final answer with \\boxed{{}}
    7. Include standalone line: "Aggregated Answer: answer"
    
    Critical Analysis Guidelines:
    - The solution MUST directly answer the original question as stated
    - Prioritize mathematical correctness over elegance
    - Reject solutions that violate any explicit constraints
    - Verify all intermediate calculations are sound
    - Ensure the final answer format matches what the problem requires"""

        response = await self.llm.generate(prompt)
        return self._extract_answer(response) or "Aggregation failed"

    @classmethod
    def _find_boxed_contents(cls, text: str) -> List[str]:
        r"""Return the non-empty contents of every ``\boxed{...}`` in ``text``, in order.

        Braces are matched by depth counting, so ``\boxed{\frac{1}{2}}``
        yields ``\frac{1}{2}``. An unterminated box is skipped.
        """
        contents: List[str] = []
        marker = cls._BOXED_MARKER
        start = text.find(marker)
        while start != -1:
            body_start = start + len(marker)
            depth = 1
            pos = body_start
            while pos < len(text) and depth:
                char = text[pos]
                if char == "{":
                    depth += 1
                elif char == "}":
                    depth -= 1
                pos += 1
            if depth == 0:
                inner = text[body_start:pos - 1].strip()
                if inner:
                    contents.append(inner)
                start = text.find(marker, pos)
            else:
                # Unbalanced box: keep looking for later, well-formed ones.
                start = text.find(marker, body_start)
        return contents

    def _extract_answer(self, text: str) -> Optional[str]:
        r"""Extract the final answer from an LLM response.

        Preference order:
          1. the last non-empty ``\boxed{...}`` (nested braces supported);
          2. the last non-empty "Final Answer: ..." or "Aggregated Answer: ..."
             line (case-insensitive). The aggregation prompt asks for the
             latter form.

        Returns None when ``text`` is empty/None or neither form is present.
        """
        if not text:
            return None

        boxed = self._find_boxed_contents(text)
        if boxed:
            return boxed[-1]

        labelled = [
            match.strip()
            for match in self._LABELLED_ANSWER_RE.findall(text)
            if match.strip()
        ]
        if labelled:
            return labelled[-1]

        return None

    @staticmethod
    def _normalize_answer(text: str) -> str:
        """Strip formatting that does not change a numeric answer's value.

        Removes ``$``, ``\\$``, ``%``, ``\\%``, thousands separators (``,``),
        surrounding whitespace and trailing periods.
        """
        cleaned = text.replace("\\$", "").replace("$", "")
        cleaned = cleaned.replace("\\%", "").replace("%", "")
        cleaned = cleaned.replace(",", "")
        return cleaned.strip().rstrip(".").strip()

    @staticmethod
    def _to_number(text: str) -> Optional[float]:
        """Parse ``text`` as a float, or return None if it is not a number."""
        try:
            return float(text)
        except ValueError:
            return None

    @classmethod
    def _answers_match(cls, given: Any, expected: Any) -> bool:
        """Decide whether a produced answer matches the reference answer.

        * A single non-digit reference (e.g. a choice letter) matches when the
          given answer ends with it (the lenient rule used previously).
        * Otherwise both are normalized (see ``_normalize_answer``). They
          match if the strings are equal, or if both are numbers within
          ``1e-10``. If the given answer is not a bare number but contains
          exactly one number (e.g. "18 dollars"), that number is compared.

        Single *digits* no longer use the suffix rule, which wrongly accepted
        "15" for an expected "5".
        """
        if given is None or expected is None:
            return False
        given_text = str(given).strip()
        expected_text = str(expected).strip()
        if not given_text or not expected_text:
            return False

        if len(expected_text) == 1 and not expected_text.isdigit():
            return given_text.endswith(expected_text)

        given_norm = cls._normalize_answer(given_text)
        expected_norm = cls._normalize_answer(expected_text)
        if given_norm == expected_norm:
            return True

        expected_num = cls._to_number(expected_norm)
        if expected_num is None:
            return False
        given_num = cls._to_number(given_norm)
        if given_num is None:
            numbers = cls._NUMBER_RE.findall(given_norm)
            if len(numbers) == 1:
                given_num = float(numbers[0])
        # Allow tiny floating-point differences.
        return given_num is not None and abs(given_num - expected_num) < 1e-10

    def _reference_answer(self, problem: Dict[str, Any]) -> Optional[str]:
        """Return the reference answer for ``problem``, or None if unavailable.

        Uses the ``####`` value from ``problem["solution"]`` when present.
        Otherwise it falls back to ``problem["answer"]``, which is also used
        when the solution has no ``####`` marker.
        """
        for key in ("solution", "answer"):
            if key in problem:
                extracted = self._extract_correct_answer(problem[key])
                if extracted is not None:
                    return extracted
        return None

    def save_results(self, result: Dict[str, Any], problem: Dict[str, Any]) -> Dict[str, Any]:
        """Build the record saved for one problem.

        NOTE: mutates ``result`` in place by removing its ``"nodes"`` entry.

        Returns:
            ``{"question", "result", "verification"}``. ``verification`` is
            ``{"is_correct", "correct_answer", "given_answer"}``;
            ``is_correct`` is decided by ``_answers_match``.
        """
        result.pop("nodes", None)

        # Prepare verification info
        verification = {
            "is_correct": False,
            "correct_answer": None,
            "given_answer": result.get("final_answer")
        }

        correct_answer = self._reference_answer(problem)
        verification["correct_answer"] = correct_answer
        if correct_answer is not None and "final_answer" in result:
            verification["is_correct"] = self._answers_match(
                result["final_answer"], correct_answer
            )

        return {
                "question": problem.get("question"),
                "result": result,
                "verification": verification
        }

    def _extract_correct_answer(self, solution: str) -> Optional[str]:
        """Extract the reference answer from a ``#### <answer>`` marker.

        Returns the stripped text after the last ``####`` marker. Returns None
        when there is no marker or ``solution`` is not a string.
        """
        if not isinstance(solution, str):
            return None
        hash_matches = re.findall(r'####\s*([^\n]+)', solution)
        return hash_matches[-1].strip() if hash_matches else None

    async def verify_answer(self, problem: Dict[str, Any], final_answer: str) -> bool:
        """Return True if ``final_answer`` matches the problem's reference answer.

        The reference comes from ``"solution"`` or ``"answer"`` (see
        ``_reference_answer``). The comparison uses the same rules as
        ``save_results``.
        """
        correct_answer = self._reference_answer(problem)
        if not correct_answer:
            return False
        return self._answers_match(final_answer, correct_answer)